{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification Analysis for ECG Time-Series\n",
    "\n",
    "> Copyright 2019 Dave Fernandes. All Rights Reserved.\n",
    "> \n",
    "> Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "> you may not use this file except in compliance with the License.\n",
    "> You may obtain a copy of the License at\n",
    ">\n",
    "> http://www.apache.org/licenses/LICENSE-2.0\n",
    ">  \n",
    "> Unless required by applicable law or agreed to in writing, software\n",
    "> distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "> See the License for the specific language governing permissions and\n",
    "> limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "This notebook classifies time-series for segmented heartbeats from ECG lead II recordings. Either of two models (CNN or RNN) can be selected from training and evaluation.\n",
    "- Data for this analysis should be prepared using the `PreprocessECG.ipynb` notebook from this project.\n",
    "- Original data is from: https://www.kaggle.com/shayanfazeli/heartbeat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.layers as keras\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "\n",
    "tf.enable_eager_execution()\n",
    "\n",
    "TRAIN_SET = './Data/train_set.pickle'\n",
    "TEST_SET = './Data/test_set.pickle'\n",
    "\n",
    "with open(TEST_SET, 'rb') as file:\n",
    "    test_set = pickle.load(file)\n",
    "    x_test = test_set['x']\n",
    "    y_test = test_set['y']\n",
    "\n",
    "with open(TRAIN_SET, 'rb') as file:\n",
    "    train_set = pickle.load(file)\n",
    "    x_train = train_set['x']\n",
    "    y_train = train_set['y']\n",
    "    \n",
    "def parameter_count():\n",
    "    total = 0\n",
    "    for v in tf.trainable_variables():\n",
    "        v_elements = 1\n",
    "        for dim in v.get_shape():\n",
    "            v_elements *= dim.value\n",
    "\n",
    "        total += v_elements\n",
    "    return total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input functions for Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combined_dataset(features, labels):\n",
    "    assert features.shape[0] == labels.shape[0]\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(({'time_series': features}, labels))\n",
    "    return dataset\n",
    "\n",
    "def class_for_element(features, labels):\n",
    "    return labels\n",
    "\n",
    "# For training\n",
    "def train_input_fn():\n",
    "    dataset = combined_dataset(x_train, y_train)\n",
    "    return dataset.repeat().shuffle(500000).batch(200).prefetch(1)\n",
    "\n",
    "# For evaluation and metrics\n",
    "def eval_input_fn():\n",
    "    dataset = combined_dataset(x_test, y_test)\n",
    "    return dataset.batch(1000).prefetch(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the models\n",
    "#### Convolutional Model\n",
    "* The convolutional model is taken from: https://arxiv.org/pdf/1805.00794.pdf\n",
    "\n",
    "Model consists of:\n",
    "* An initial 1-D convolutional layer\n",
    "* 5 repeated residual blocks (`same` padding)\n",
    "* A fully-connected layer\n",
    "* A linear layer with softmax output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CNN_MODEL_DIR = './Models/CNN-Paper'\n",
    "\n",
    "def conv_unit(unit, input_layer):\n",
    "    s = '_' + str(unit)\n",
    "    layer = keras.Conv1D(name='Conv1' + s, filters=32, kernel_size=5, strides=1, padding='same', activation='relu')(input_layer)\n",
    "    layer = keras.Conv1D(name='Conv2' + s, filters=32, kernel_size=5, strides=1, padding='same', activation=None)(layer )\n",
    "    layer = keras.Add(name='ResidualSum' + s)([layer, input_layer])\n",
    "    layer = keras.Activation(\"relu\", name='Act' + s)(layer)\n",
    "    layer = keras.MaxPooling1D(name='MaxPool' + s, pool_size=5, strides=2)(layer)\n",
    "    return layer\n",
    "\n",
    "def cnn_model(input_layer, mode, params):\n",
    "    current_layer = keras.Conv1D(filters=32, kernel_size=5, strides=1)(input_layer)\n",
    "\n",
    "    for i in range(5):\n",
    "        current_layer = conv_unit(i + 1, current_layer)\n",
    "\n",
    "    current_layer = keras.Flatten()(current_layer)\n",
    "    current_layer = keras.Dense(32, name='FC1', activation='relu')(current_layer)\n",
    "    logits = keras.Dense(5, name='Output')(current_layer)\n",
    "    \n",
    "    print('Parameter count:', parameter_count())\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Recurrent Model\n",
    "\n",
    "Model consists of:\n",
    "* Two stacked bidirectional GRU layers\n",
    "* Two fully-connected layers\n",
    "* A linear layer with softmax output\n",
    "\n",
    "Since the model operates on segmented heartbeat samples, we can use a bidirectional RNN because the whole segment is available for processing at one time. It is also a more \\\"fair\\\" comparison with the CNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "RNN_MODEL_DIR = './Models/RNN'\n",
    "RNN_OUTPUT_UNITS = [64, 128]\n",
    "\n",
    "def rnn_model(input_layer, mode, params):\n",
    "    current_layer = tf.keras.layers.Masking(mask_value=0., input_shape=(187, 1), name='Masked')(input_layer)\n",
    "    \n",
    "    for i, size in enumerate(RNN_OUTPUT_UNITS):\n",
    "        notLast = i + 1 < len(RNN_OUTPUT_UNITS)\n",
    "        layer = tf.keras.layers.GRU(size, return_sequences=notLast, dropout=0.2, name = 'GRU' + str(i+1))\n",
    "        current_layer = keras.Bidirectional(layer, name = 'BiRNN' + str(i+1))(current_layer)\n",
    "\n",
    "    current_layer = keras.Dense(64, name='Dense1', activation='relu')(current_layer)\n",
    "    current_layer = keras.Dense(16, name='Dense2', activation='relu')(current_layer)\n",
    "    logits = keras.Dense(5, name='Output', activation='relu')(current_layer)\n",
    "    \n",
    "    print('Parameter count:', parameter_count())\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimator setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial learning rate\n",
    "INITIAL_LEARNING_RATE = 0.001\n",
    "\n",
    "# Learning rate decay per LR_DECAY_STEPS steps (1.0 = no decay)\n",
    "LR_DECAY_RATE = 0.5\n",
    "\n",
    "# Number of steps for LR to decay by LR_DECAY_RATE\n",
    "LR_DECAY_STEPS = 4000\n",
    "\n",
    "# Threshold for gradient clipping\n",
    "GRADIENT_NORM_THRESH = 10.0\n",
    "\n",
    "# Select model to train\n",
    "MODEL_DIR = CNN_MODEL_DIR\n",
    "MODEL_FN = cnn_model\n",
    "\n",
    "def classifier_fn(features, labels, mode, params):\n",
    "    is_training = mode == tf.estimator.ModeKeys.TRAIN\n",
    "    input_layer = tf.feature_column.input_layer(features, params['feature_columns'])\n",
    "    input_layer = tf.expand_dims(input_layer, -1)\n",
    "\n",
    "    logits = MODEL_FN(input_layer, mode, params)\n",
    "\n",
    "    # For prediction, exit here\n",
    "    predicted_classes = tf.argmax(logits, 1)\n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        predictions = {\n",
    "            'class_ids': predicted_classes[:, tf.newaxis],\n",
    "            'probabilities': tf.nn.softmax(logits),\n",
    "            'logits': logits,\n",
    "        }\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n",
    "\n",
    "    # For training and evaluation, compute the loss (MSE)\n",
    "    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\n",
    "\n",
    "    accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes, name='acc_op')\n",
    "    metrics = {'accuracy': accuracy}\n",
    "    tf.summary.scalar('accuracy', accuracy[1])\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.EVAL:\n",
    "        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)\n",
    "\n",
    "    # For training...\n",
    "    global_step = tf.train.get_global_step()\n",
    "    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, LR_DECAY_STEPS, LR_DECAY_RATE)\n",
    "\n",
    "    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
    "    #optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, GRADIENT_NORM_THRESH)\n",
    "    \n",
    "    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())\n",
    "    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_columns = [tf.feature_column.numeric_column('time_series', [187])]\n",
    "\n",
    "estimator = tf.estimator.Estimator(\n",
    "    model_fn=classifier_fn,\n",
    "    model_dir=MODEL_DIR,\n",
    "    params={\n",
    "        'feature_columns': feature_columns,\n",
    "    })\n",
    "\n",
    "estimator.train(train_input_fn, steps=4000)\n",
    "info = estimator.evaluate(input_fn=eval_input_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as skm\n",
    "import seaborn\n",
    "\n",
    "dataset_fn = eval_input_fn\n",
    "\n",
    "predictions = estimator.predict(input_fn=dataset_fn)\n",
    "y_pred = []\n",
    "y_prob = []\n",
    "\n",
    "for i, value in enumerate(predictions):\n",
    "    class_id = value['class_ids']\n",
    "    y_pred.append(class_id)\n",
    "    probabilities = value['probabilities']\n",
    "    y_prob.append(probabilities[class_id])\n",
    "del predictions\n",
    "\n",
    "y_pred = np.array(y_pred)\n",
    "y_prob = np.array(y_prob)\n",
    "y_test = np.reshape(y_test, (len(y_test), 1))\n",
    "\n",
    "# Classification report\n",
    "report = skm.classification_report(y_test, y_pred)\n",
    "print(report)\n",
    "\n",
    "# Confusion matrix\n",
    "cm = skm.confusion_matrix(y_test, y_pred)\n",
    "seaborn.heatmap(cm, annot=True,annot_kws={\"size\": 16})\n",
    "\n",
    "y_prob_correct = y_prob[y_pred == y_test]\n",
    "y_prob_mis = y_prob[y_pred != y_test]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check probability estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from astropy.stats import binom_conf_interval\n",
    "\n",
    "_, _, _ = plt.hist(y_prob, 10, (0, 1))\n",
    "plt.xlabel('Belief')\n",
    "plt.ylabel('Count')\n",
    "plt.title('All Predictions')\n",
    "plt.show();\n",
    "\n",
    "n_all, bins = np.histogram(y_prob, 10, (0, 1))\n",
    "n_correct, bins = np.histogram(y_prob_correct, 10, (0, 1))\n",
    "\n",
    "f_correct = n_correct / np.clip(n_all, 1, None)\n",
    "f_bins = 0.5 * (bins[:-1] + bins[1:])\n",
    "\n",
    "n_correct = n_correct[n_all > 0]\n",
    "n_total = n_all[n_all > 0]\n",
    "f_correct = n_correct / n_total\n",
    "f_bins = f_bins[n_all > 0]\n",
    "\n",
    "lower_bound, upper_bound = binom_conf_interval(n_correct, n_total)\n",
    "error_bars = np.array([f_correct - lower_bound, upper_bound - f_correct])\n",
    "\n",
    "plt.plot([0., 1.], [0., 1.])\n",
    "plt.errorbar(f_bins, f_correct, yerr=error_bars, fmt='o')\n",
    "plt.xlabel('Softmax Probability')\n",
    "plt.ylabel('Frequency')\n",
    "plt.title('Correct Predictions')\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
